
Add unified encoder pytorch implementation #251


Open · wants to merge 135 commits into main

Conversation

@CeliaBenquet (Member) commented May 1, 2025

This PR adds a PyTorch implementation of a unified CEBRA encoder, which is composed of:

  • A new sampling scheme that samples across all sessions so that they can be aligned on the neuron axis to train a single encoder.
  • A unified Dataset and Loader, adapted to the new sampling scheme.
  • A unified Solver that considers multiple sessions to be aligned at inference.
  • A new masked modeling training option, with different types of masking (see the sketch below).

🚧 A preprint is pending "Unified CEBRA Encoders for Integrating Neural Recordings via Behavioral Alignment" by Célia Benquet, Hossein Mirzaei, Steffen Schneider, Mackenzie W. Mathis.
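
As a rough illustration of the masked-modeling option, here is a minimal sketch; the function name, the zero-masking strategy, and the mask_ratio parameter are illustrative assumptions, not the PR's actual API:

import torch

def apply_random_mask(batch: torch.Tensor, mask_ratio: float = 0.3) -> torch.Tensor:
    # Zero out a random subset of entries per sample, in the spirit of
    # masked-modeling training. The masking type and ratio here are
    # assumptions; the PR supports several masking variants.
    mask = (torch.rand_like(batch) > mask_ratio).float()
    return batch * mask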

@CeliaBenquet CeliaBenquet requested review from stes and MMathisLab May 20, 2025 10:51
@MMathisLab MMathisLab changed the base branch from batched-inference-and-padding to main May 23, 2025 13:39
positive=self[index.positive],
negative=self[index.negative],
reference=self[index.reference],
positive=self.apply_mask(self[index.positive]),
Member

Quick sanity check; is this backwards compatible? @CeliaBenquet

Member Author

Yes, it's backward compatible. I added a check for the case where the function doesn't exist, for people who might want to use the adapt functionality on an older model. Good catch.

@MMathisLab (Member) left a comment

Thanks @CeliaBenquet! I went through and left comments for discussion.

Member

Should we put this into integrations rather than models? Models, to me, is encoders only. cc @stes

Member

We currently have some decoders here, although these are sklearn specific.

I think this module here is fine; at least right now I don't see a better place in the codebase to put them. An argument to leave them here would be that they are an "extension" of the encoders we train, plus they are "raw" torch objects, all of which we currently collect in cebra.models.

I don't have a strong opinion, I just don't see where they would fit better... In integrations, we currently have only "standalone" helper functions, which these aren't.

@CeliaBenquet Where are these decoders used around the codebase, and how are they trained?

Member

okay, I see

@@ -0,0 +1,38 @@
import torch.nn as nn
Member

Why not have decoders somewhere like integrations? To me, models is the encoders only. cc @stes

Member

and this :D
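
For context on the thread above, a minimal sketch of the kind of standalone torch decoder being discussed; the class name, layer sizes, and architecture are illustrative assumptions, not the actual decoders added in this PR:

import torch
import torch.nn as nn

class MLPDecoder(nn.Module):
    # Hypothetical example: a small decoder mapping CEBRA embeddings to a
    # behavioral variable. Shown only to illustrate why such modules could
    # live either next to the encoders (cebra.models) or in integrations.
    def __init__(self, embedding_dim: int, output_dim: int, hidden_dim: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, embedding: torch.Tensor) -> torch.Tensor:
        return self.net(embedding)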

@stes (Member) left a comment

Looks good overall; left some comments!

  • Implementation of the Mixin class for the masking: if I understood correctly, the only change is that this apply_mask function is applied after loading a batch. This seems like a change that could be applied minimally invasively not in the dataset, but in the data loader. Is there a good case why the datasets themselves need to be modified?
  • Discussion on where to place the decoders: currently in cebra.models.decoders. Are the decoders useful as "standalone" models? Where are they currently used? Based on that we could determine whether to move them, e.g., as standalone modules to integrations.
  • See other comments; mostly on class design, removing duplicated code, etc.

Comment on lines +100 to +123
if hasattr(self, "apply_mask"):
    batch = [
        cebra_data.Batch(
            reference=self.apply_mask(
                session[index.reference[session_id]]),
            positive=self.apply_mask(
                session[index.positive[session_id]]),
            negative=self.apply_mask(
                session[index.negative[session_id]]),
            index=index.index,
            index_reversed=index.index_reversed,
        ) for session_id, session in enumerate(self.iter_sessions())
    ]
else:
    batch = [
        cebra_data.Batch(
            reference=session[index.reference[session_id]],
            positive=session[index.positive[session_id]],
            negative=session[index.negative[session_id]],
            index=index.index,
            index_reversed=index.index_reversed,
        ) for session_id, session in enumerate(self.iter_sessions())
    ]
return batch
Member

Can we convert this if/else statement into a subclass?

Member Author

This is a backward-compatibility check for old models; I don't know if it's worth it, no? Ideally we wouldn't have it. I added it after Mackenzie's comment.

Member

Under which circumstances would the apply_mask function be missing?

Comment on lines +67 to +80
if hasattr(self, "apply_mask"):
    # If the dataset has a mask, apply it to the data.
    batch = Batch(
        positive=self.apply_mask(self[index.positive]),
        negative=self.apply_mask(self[index.negative]),
        reference=self.apply_mask(self[index.reference]),
    )
else:
    batch = Batch(
        positive=self[index.positive],
        negative=self[index.negative],
        reference=self[index.reference],
    )
return batch
Member

See above; a better way to implement this would be to have the masking simply override the load_batch function, rather than introducing this if/else logic.
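
A rough sketch of that suggestion, reusing the Batch and apply_mask names from the snippet above (a hedged example only; the MaskedMixin name and the exact class composition are assumptions about how the refactor could look, not the PR's final design):

class MaskedMixin:
    # Hypothetical mixin: a dataset that supports masking overrides
    # load_batch and wraps the parent's result, so the plain datasets keep
    # their unconditional load_batch and the if/else check disappears.
    def load_batch(self, index):
        batch = super().load_batch(index)
        return Batch(
            reference=self.apply_mask(batch.reference),
            positive=self.apply_mask(batch.positive),
            negative=self.apply_mask(batch.negative),
        )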

Member Author

This is a backward-compatibility check for old models; I don't know if it's worth it, no? Ideally we wouldn't have it. I added it after Mackenzie's comment.

@MMathisLab (Member) left a comment

LGTM! Just the one comment on kwargs seems critical to decide.

@MMathisLab MMathisLab requested a review from stes May 28, 2025 23:02